import itertools
import json
import os

import numpy as np
# import seaborn as sns
# import matplotlib.pyplot as plt
# import collections
# old constants
from datetime import datetime

from tfFunctionsUtils import getdoKey, generate_permutations


class Experiment:
    """Configuration and bookkeeping for one causal-generative-model experiment.

    Hyper-parameters are read from ``**kwargs`` (falling back to the defaults
    written below). The ground-truth SCM description is obtained by calling
    ``set_truedag`` and its 24 returned pieces are stored as attributes. When
    ``new_experiment`` is truthy, the on-disk directory layout under
    ``main_path`` is created and the new run's save path is recorded in
    ``<main_path>/<Complete_DAG_desc>/SHARED_INFO.txt``.

    Args:
        set_truedag: callable ``(noise_states, latent_state, obs_state,
            Data_intervs)`` returning the 24-tuple unpacked in ``__init__``
            (e.g. ``set_Xray``).
        **kwargs: optional overrides for any attribute read via ``kwargs.get``.
    """

    def __init__(self, set_truedag, **kwargs):

        self.PROJECT_NAME = kwargs.get('PROJECT_NAME', 'XrayImage')
        self.exp_name = kwargs.get('exp_name', self.PROJECT_NAME)
        self.PLOTS_PER_EPOCH = 1

        self.tvd_diff = {}

        # Noise / network sizes.
        self.NOISE_DIM = kwargs.get('NOISE_DIM', 128)
        self.CONF_NOISE_DIM = kwargs.get('CONF_NOISE_DIM', 128)
        self.generator_decay = 1e-6
        self.discriminator_decay = 1e-6
        self.IMAGE_NOISE_DIM = kwargs.get('IMAGE_NOISE_DIM', 100)
        self.IMAGE_FILTERS = kwargs.get('IMAGE_FILTERS', [128, 64, 32])
        self.IMAGE_SIZE = kwargs.get('IMAGE_SIZE', 32)
        self.ENCODED_DIM = kwargs.get('ENCODED_DIM', 10)

        self.obs_state = kwargs.get('obs_state', 2)

        # Hidden-layer widths for generator / discriminator; no default here.
        # Example values used previously: G=[10, 25, 25, 10], D=[10, 15, 10, 5].
        self.G_hid_dims = kwargs.get('G_hid_dims')
        self.D_hid_dims = kwargs.get('D_hid_dims')

        # Training hyper-parameters.
        self.CRITIC_ITERATIONS = kwargs.get('CRITIC_ITERATIONS', 5)
        self.LAMBDA_GP = kwargs.get('LAMBDA_GP', 0.1)  # It was 0.3
        self.learning_rate = kwargs.get('learning_rate', 2 * 1e-5)
        self.betas = (0.5, 0.9)
        self.Synthetic_Sample_Size = kwargs.get('Synthetic_Sample_Size', 20000)
        self.intv_Sample_Size = kwargs.get('intv_Sample_Size', 20000)
        self.ex_row_size = kwargs.get('ex_row_size', 20)
        self.batch_size = kwargs.get('batch_size', 100)  # from 256
        self.intv_batch_size = kwargs.get('intv_batch_size', 100)  # from 256
        self.num_epochs = kwargs.get('num_epochs', 300)
        self.STOPAGE1 = 50
        self.STOPAGE2 = 20000
        self.lr_dec = 1

        self.curr_epoochs = 0
        self.curr_iter = 0

        # Gumbel-softmax temperature schedule (see anneal_temperature).
        self.temp_min = kwargs.get('temp_min', 0.00001)
        self.ANNEAL_RATE = 0.00003
        self.start_temp = kwargs.get('Temperature', 0.5)
        self.Temperature = kwargs.get('Temperature', 0.5)

        self.dataset_activated = kwargs.get('dataset_activated', [0])

        self.SAVE_MODEL = True
        self.LOAD_MODEL = False
        self.LOAD_TRAINED_CONTROLLER = False
        self.load_which_models = {}
        self.pre_trained_by_others = []
        self.checkpoints = {}

        now = datetime.now()
        self.curDATE = now.strftime("%b_%d_%Y")
        self.curTIME = now.strftime("%H_%M")

        self.Data_intervs = kwargs.get('Data_intervs', [])
        self.Data_observs = kwargs.get('Data_observs', [])
        self.num_datasets = len(self.Data_intervs)

        self.G_avg_losses = []
        self.D_avg_losses = []

        # SCM ground-truth parameters forwarded to set_truedag.
        self.noise_states = kwargs.get('noise_states', 8)
        self.latent_state = kwargs.get('latent_state', 8)
        self.dist_thresh = kwargs.get('dist_thresh', 0.2)
        self.allowed_noise = kwargs.get('allowed_noise', 0.50)

        self.causal_hierarchy = kwargs.get('causal_hierarchy', 1)

        (self.DAG_desc, self.Complete_DAG_desc, self.Complete_DAG, self.complete_labels,
         self.Observed_DAG, self.label_names, self.image_labels, self.rep_labels,
         self.interv_queries, self.cf_queries, self.latent_conf,
         self.confTochild, self.exogenous, self.cf_intervene, self.cf_observe,
         self.cf_evidence, self.cflabel_names, self.twin_map, self.Twin_Network,
         self.cf_exogenous,
         self.noise_params, self.train_mech_dict, self.label_dim, self.plot_title) = set_truedag(
            self.noise_states, self.latent_state, self.obs_state, self.Data_intervs)

        self.true_bn = kwargs.get('true_bn', None)
        self.features = kwargs.get('features', ["digit", "thickness", "color"])

        self.cf_samples = self.Synthetic_Sample_Size
        self.num_labels = len(self.label_names)

        # Paths are built by plain string concatenation, so main_path is
        # expected to end with a separator (as the default does).
        main_path = kwargs.get('main_path', "./SAVED_EXPERIMENTS/")
        dag_root = main_path + self.Complete_DAG_desc

        self.new_experiment = kwargs.get('new_experiment', True)

        # Default run directory: <main_path><DAG>/<exp_name>/<date>-<time>.
        # Resolved for every run (not only new ones) so that a resumed
        # experiment still has SAVED_PATH / LOAD_MODEL_PATH and caller
        # overrides are honoured.
        saved_path = dag_root + "/" + self.exp_name + "/" + self.curDATE + "-" + self.curTIME
        self.SAVED_PATH = kwargs.get('SAVED_PATH', saved_path)
        self.LOAD_MODEL_PATH = kwargs.get('LOAD_MODEL_PATH', self.SAVED_PATH)

        if self.new_experiment:
            # Directories for saved SCMs and the preprocessed dataset.
            os.makedirs(dag_root + "/SCMs", exist_ok=True)
            os.makedirs(dag_root + "/preprocessed_dataset/", exist_ok=True)

            # Record this run as the most recent experiment for this DAG.
            with open(dag_root + "/SHARED_INFO.txt", 'w', encoding='utf-8') as fp:
                fp.write(json.dumps({"last_exp": self.SAVED_PATH}))

        self.SCM_PATH = kwargs.get('SCM_PATH', dag_root + "/SCMs/" + self.exp_name + ".txt")

        self.Intv_SCMs = dag_root + "/SCMs/interventions/"
        self.Cf_SCMs = dag_root + "/SCMs/counterfactuals/"

        self.file_roots = dag_root + "/preprocessed_dataset/"

        self.isJoint = False
        self._data_sampler = None
        self.test_marginals = False
        self.bayesNet = None

    def anneal_temperature(self, tot_iters):
        """Set the Gumbel-softmax temperature for global iteration ``tot_iters``.

        Uses the exponential schedule
        ``max(start_temp * exp(-ANNEAL_RATE * tot_iters), temp_min)`` and
        prints the new value.

        Args:
            tot_iters: total number of iterations performed so far.
        """
        # Decay from the *initial* temperature: tot_iters is already the total
        # iteration count, so decaying the current value would compound the
        # schedule on every call.
        self.Temperature = np.maximum(
            self.start_temp * np.exp(-self.ANNEAL_RATE * tot_iters), self.temp_min)
        print(tot_iters, ":Temperature", self.Temperature)





class CausalGraph():
    """Observed causal DAG augmented with latent confounder nodes.

    Args:
        name: description string stored as both ``DAG_desc`` and
            ``Complete_DAG_desc``.
        dag: mapping ``variable -> list of observed parents``.
        confs: mapping ``confounder name -> list of observed children``. Every
            child must be a key of ``dag`` (otherwise ``KeyError``).
        dims: mapping ``variable -> dimension``. NOT copied: confounder
            dimensions are written into this very dict, and callers (e.g.
            ``set_Xray``) rely on later mutations of it showing up in
            ``self.label_dim``.
        num_latent: dimension assigned to every confounder in ``label_dim``.

    ``Complete_DAG`` lists confounders first (as root nodes), then each
    observed variable with parents ``[its confounders..., its observed parents...]``.
    """

    def __init__(self, name, dag, confs, dims, num_latent):
        self.DAG_desc = name
        self.Complete_DAG_desc = name
        self.Observed_DAG = dag

        self.confTochild = confs
        self.num_confs = len(confs)

        # Confounders are root nodes. Use their actual names so the complete
        # graph agrees with latent_conf and label_dim (previously the nodes
        # were always named U0..U{n-1}, regardless of the keys of `confs`).
        self.Complete_DAG = {cnf: [] for cnf in confs}

        # Confounders affecting each observed variable.
        self.latent_conf = {var: [] for var in dag}
        for cnf, children in confs.items():
            for var in children:
                self.latent_conf[var].append(cnf)

        # Parents in the complete graph: confounders first, then observed
        # parents (a fresh list, independent of latent_conf and dag).
        for var, parents in dag.items():
            self.Complete_DAG[var] = self.latent_conf[var] + parents

        self.complete_labels = list(self.Complete_DAG.keys())
        self.label_names = list(self.Observed_DAG.keys())

        self.label_dim = dims  # intentionally aliased, see docstring
        for cnf in self.confTochild:
            self.label_dim[cnf] = num_latent

        # Filled in by the graph-specific setup function.
        self.image_labels = None
        self.rep_labels = None



def set_Xray(noise_states, latent_state, obs_state, Data_intervs):
    """Build the SCM description for the COVID-19 / chest X-ray experiment.

    Graph: covid_19 -> xray -> {pneum, Rxray}, with a latent confounder U0
    shared by covid_19 and pneum. ``xray`` is the image node and ``Rxray``
    its representation node.

    Args:
        noise_states: number of states of each exogenous noise variable.
        latent_state: number of states of each latent confounder.
        obs_state: unused here; kept for the common ``set_truedag`` signature.
        Data_intervs: unused here; kept for the common ``set_truedag`` signature.

    Returns:
        The 24-tuple unpacked by ``Experiment.__init__`` (DAG descriptions,
        graph structure, label lists, interventional queries, empty
        counterfactual placeholders, noise parameters, per-mechanism training
        specs, label dimensions and plot title).
    """
    observed_dag = {
        "covid_19": [],
        "xray": ['covid_19'],
        "pneum": ['xray'],
        "Rxray": ["xray"],
    }
    conf_to_child = {"U0": ["covid_19", "pneum"]}

    # NOTE(review): original author flagged the image ('xray': 0) and
    # representation ('Rxray': 100) dimensions as unresolved.
    # This dict is shared with graph.label_dim and is mutated further below.
    label_dim = {"covid_19": 2, 'xray': 0, "pneum": 2, 'Rxray': 100}
    graph = CausalGraph(name="xrayImage", dag=observed_dag, confs=conf_to_child,
                        dims=label_dim, num_latent=latent_state)
    graph.image_labels = ["xray"]
    graph.rep_labels = ["Rxray"]

    plot_title = "covid_19x CXR-3 experiment"

    # (observed variables, intervened variables):
    #   P(pneum, covid_19) and P(pneum | do(covid_19))
    query_specs = [
        (['covid_19', 'pneum'], []),
        (['pneum'], ['covid_19']),
    ]
    query_keys = [getdoKey(observed, do_vars) for observed, do_vars in query_specs]

    interv_queries = []
    for (observed, do_vars), key in zip(query_specs, query_keys):
        # Enumerate every joint assignment of the intervened variables.
        assignments = generate_permutations([label_dim[v] for v in do_vars])
        interv_queries.append({
            "obs": observed,
            "intervs": [dict(zip(do_vars, combo)) for combo in assignments],
            "expr": key,
        })

    cf_queries = []

    # Exogenous noise node name for every non-image label.
    exogenous = {name: "n" + name for name in graph.label_names
                 if name not in graph.image_labels}

    # Counterfactual structures are not used by this experiment.
    cflabel_names = []
    Twin_Network = {}
    cf_exogenous = {}
    cf_intervene = {}
    cf_observe = []
    cf_evidence = {}
    twin_map = {}

    noise_params = {"n" + name: (0.5, noise_states) for name in observed_dag}
    noise_params.update({conf: (0.1, latent_state) for conf in conf_to_child})

    # Per-mechanism training spec. Per the original author: 'compare' lists
    # the variables whose joint is matched, 'parents' the variables to
    # intervene on. NOTE: 'intv' is an empty dict for covid_19 but a set
    # elsewhere; kept as-is since downstream handling is not visible here.
    train_mech_dict = {
        "covid_19": [{'parents': [], 'intv': {}, 'compare': ['covid_19']}],
        "xray": [{'parents': ['covid_19'], 'intv': {'covid_19'}, 'compare': ['xray']}],
        "pneum": [{'parents': ['xray'], 'intv': {'xray'}, 'compare': ['pneum']}],
        "Rxray": [{'parents': ['xray'], 'intv': {'xray'}, 'compare': ['Rxray']}],
    }

    # Noise dimensions go into the same dict object as graph.label_dim.
    for name in observed_dag:
        if name in graph.image_labels:
            continue
        label_dim["n" + name] = noise_states

    return (graph.DAG_desc, graph.Complete_DAG_desc, graph.Complete_DAG, graph.complete_labels,
            graph.Observed_DAG, graph.label_names, graph.image_labels, graph.rep_labels,
            interv_queries, cf_queries, graph.latent_conf,
            graph.confTochild, exogenous, cf_intervene, cf_observe, cf_evidence,
            cflabel_names, twin_map, Twin_Network, cf_exogenous,
            noise_params, train_mech_dict, graph.label_dim, plot_title)


